"""Stefan Spence 28.08.20
Demonstrate the effect of changing pulse duration during a RSC sequence"""
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import json
import sys
import os
# Run relative to this script's own directory so that the relative paths used
# in this file (style sheet, JSON results, output PDF) resolve wherever the
# script is launched from. abspath() is needed because __file__ may be a bare
# relative name, in which case dirname() would return '' and chdir('') raises.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
plt.style.use('../../vinstyle.mplstyle')
from qutip import *
from RSCHamiltonians import hbar, kB, simulate
import time

#### run simulation
# --- physical parameters ---------------------------------------------------
T = 15  # initial temperature in uK
w = 120*2*np.pi  # trap frequency (2pi kHz)
Or = 25*2*np.pi  # carrier Rabi frequency (2pi kHz)
wr = 2.07*2*np.pi  # recoil frequency (2pi kHz)
eta = (wr/w)**0.5  # sqrt(recoil frequency / trap frequency)

# --- Hilbert space and observables -----------------------------------------
n_ph = 15  # number of harmonic levels to include
# Both operators act as the identity on the second, 2-dimensional subsystem.
n = tensor(num(n_ph), qeye(2))  # number operator
gs = tensor(fock_dm(n_ph, 0), qeye(2))  # motional ground state

# --- scan grid -------------------------------------------------------------
dets = np.linspace(0.75, 1.15, 45)  # detunings to test (d/w)
n_ds = dets.size
# Reorder by distance of each index from the central index, so that the
# middle of the detuning range comes first.
dets = np.take(dets, np.argsort(np.abs(np.arange(n_ds) - n_ds//2)))
durs = np.linspace(0.15, 1, 50)  # pulse durations to test (T / (pi/(eta*Or)))

# --- result storage, indexed [detuning, duration] ---------------------------
final_n = np.zeros((n_ds, durs.size))  # store the final phonon number
pngs = np.zeros_like(final_n)  # probability not in ground state
# results = [] # store results of simulation in case it's useful to look back at
ii = 0
t0 = time.time()  # wall-clock reference taken before any simulation starts

#%% run simulations
# for i, d in enumerate(dets):
#     for j, dr in enumerate(durs):
#         tlist, result = simulate(Np=n_ph, W=w, D=w*d, LD=eta, LDop=eta, Temp=T,
#               Rabi=Or, Ntp=1500, dR=dr, N=30, pgb=None)
#         # results.append(result)
#         final_n[i, j] = np.real(expect(n, result.states[-1]))
#         pngs[i, j] = 1 - np.real(expect(gs, result.states[-1]))
#         print(i, j, final_n[i, j], end=',', flush=True)
        
#     n_complete = (i*len(durs) + j)
#     ave_t = (time.time() - t0) / n_complete
#     print('\nTime remaining: %s minutes\n'%(ave_t*(n_ds*len(durs) - n_complete)/60.))
#     # ###' save results
#     results_dict = {"T (uK)": T, "Trap Frequency (kHz)": w/2/np.pi, "Raman LD parameter": eta,
#     "OP LD parameter": eta, "Rabi Frequency (kHz)": Or/2/np.pi, "OP scattering rate (/s)": 100,
#     "Depump scattering rate (/s)": 0, "Number of OP photons": 3, "Number of pulses": 30,
#     "Number harmonic levels": n_ph, "Detuning d/w": list(dets), "Duration T/(pi/nOr)": list(durs), 
#     "Final Phonon Number": [list(fn) for fn in final_n],
#     "1 - Dark State Fraction": [list(fn) for fn in pngs]}
#     with open('Detuning_VS_PulseDur.json', 'w+') as f:
#         json.dump(results_dict, f)

#%% load previous results
# Read the stored scan back from disk. The encoding is given explicitly so the
# result does not depend on the platform's locale default.
with open(r'Detuning_VS_PulseDur.json', encoding='utf-8') as f:
    results_dict = json.load(f)


#%% plot
# Axes of the stored scan and the quantity to display.
x = results_dict['Detuning d/w']
y = results_dict['Duration T/(pi/nOr)']
label = '1 - Dark State Fraction'   # 'Final Phonon Number'   '1 - Dark State Fraction'
# Put the rows of z (one row per detuning) into ascending-detuning order,
# whatever order the detunings were stored in.
inds = np.argsort(x)
z = np.array(results_dict[label])[inds]

# A logarithmic colour scale cannot represent values <= 0, so the lower limit
# is taken from the smallest strictly positive entry rather than from the
# global minimum. When every entry is positive this is the same value.
positive = z[z > 0]
if positive.size == 0:
    raise ValueError("'%s' has no positive values to show on a log scale" % label)

plt.figure(figsize=(7,5))
# plt.title('$1 - P(n=0)$')
# z is transposed so that detuning runs along the horizontal axis and pulse
# duration along the vertical axis, matching the extent and axis labels.
im = plt.imshow(z.T,
                    extent = (min(x), max(x), min(y), max(y)),
                    origin = 'lower',
                    cmap = 'Reds',
                    aspect = 'auto',
                    norm = LogNorm(vmin=positive.min(), vmax=np.max(z)))
plt.colorbar(im, orientation='vertical', label='$1 - P(n=0)$')
# Raw strings keep the LaTeX backslashes literal; the label text is unchanged.
plt.xlabel(r'Detuning from carrier $\delta/\omega_{trap}$')
plt.ylabel(r'Raman Pulse Duration $T/(\pi/\eta\Omega_R$)')
plt.savefig(r'Simulation_Detuning.pdf')
plt.show()
